{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# PySpark LLM Inference: Qwen-2.5 Text Summarization\n",
    "\n",
    "In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using open weights on Huggingface.\n",
    "\n",
    "The Qwen-2.5-7b-instruct is an instruction-fine-tuned version of the Qwen-2.5-7b base model. We'll show how to use the model to perform text summarization.\n",
    "\n",
    "**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU RAM. Make sure your instances have sufficient GPU capacity."
   ]
  },
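  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check on the memory note above, weight memory is approximately the parameter count times the bytes per parameter. The sketch below assumes about 7.61 billion parameters (the figure listed on the model card; treat it as an assumption here) and 2 bytes per parameter for 16-bit precision. The helper name `weight_memory_gib` is hypothetical, and the estimate ignores the KV cache and activations, which need additional memory.\n",
    "\n",
    "```python\n",
    "def weight_memory_gib(num_params: float, bytes_per_param: int = 2) -> float:\n",
    "    \"\"\"Estimate the memory needed for model weights alone, in GiB.\"\"\"\n",
    "    return num_params * bytes_per_param / 1024**3\n",
    "\n",
    "# Assumed parameter count for Qwen2.5-7B-Instruct; 2 bytes per parameter for FP16/BF16.\n",
    "print(f\"~{weight_memory_gib(7.61e9):.1f} GiB for weights\")\n",
    "```\n",
    "\n",
    "Whatever GPU memory remains after the weights are loaded is what the inference engine can use for the KV cache, so smaller GPUs leave less room for long contexts and large batches."
   ]
  },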
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n",
    "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n",
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n",
    "\n",
    "# vLLM does CUDA init at import time. Forking will try to re-initialize CUDA if vLLM was imported before and throw an error.\n",
    "os.environ[\"VLLM_WORKER_MULTIPROC_METHOD\"] = \"spawn\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check the cluster environment to handle any platform-specific configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
    "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
    "on_standalone = not (on_databricks or on_dataproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For cloud environments, load the model to the distributed file system.\n",
    "if on_databricks:\n",
    "    models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
    "    model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
    "elif on_dataproc:\n",
    "    models_dir = \"/mnt/gcs/spark-dl-models\"\n",
    "    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
    "    model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
    "else:\n",
    "    model_path = os.path.abspath(\"qwen-2.5-7b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the model from huggingface hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eeb0daf2bd7948bebd94ce2a9a5a01b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "model_path = snapshot_download(\n",
    "    repo_id=\"Qwen/Qwen2.5-7B-Instruct\",\n",
    "    local_dir=model_path\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are currently some hanging issues with vLLM's `torch.compile` on Databricks, which we are working to resolve. For now we will enforce eager mode on Databricks, which disables compilation at some performance cost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "enforce_eager = True if on_databricks else False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Warmup: Running locally\n",
    "\n",
    "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import LLM, SamplingParams\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=128)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")\n",
    "llm = LLM(model=model_path, gpu_memory_utilization=0.95, max_model_len=6600, enforce_eager=enforce_eager)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = {\n",
    "    \"role\": \"system\",\n",
    "    \"content\": \"You are a knowledgeable AI assistant that provides accurate answers to questions.\"\n",
    "}\n",
    "\n",
    "queries = [\n",
    "    \"What does CUDA stand for?\",\n",
    "    \"In one sentence, what's the difference between a CPU and a GPU?\",\n",
    "    \"What's the hottest planet in the solar system?\"\n",
    "]\n",
    "\n",
    "prompts = [\n",
    "    [\n",
    "        system_prompt,\n",
    "        {\"role\": \"user\", \"content\": query}\n",
    "    ] for query in queries\n",
    "]\n",
    "\n",
    "text = tokenizer.apply_chat_template(\n",
    "    prompts,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processed prompts: 100%|██████████| 3/3 [00:01<00:00,  1.76it/s, est. speed input: 63.83 toks/s, output: 100.14 toks/s]\n"
     ]
    }
   ],
   "source": [
    "outputs = llm.generate(text, sampling_params=sampling_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q: What does CUDA stand for?\n",
      "A: CUDA stands for Compute Unified Device Architecture. It is a parallel computing platform and application programming interface (API) model created by NVIDIA. CUDA allows software developers to use a CUDA-enabled graphics processing unit (GPU) for general purpose processing.\n",
      "\n",
      "Q: In one sentence, what's the difference between a CPU and a GPU?\n",
      "A: A CPU (Central Processing Unit) is designed for general-purpose processing and managing the overall operations of a computer, while a GPU (Graphics Processing Unit) is specialized for parallel processing tasks, particularly those related to rendering graphics and accelerating machine learning tasks.\n",
      "\n",
      "Q: What's the hottest planet in the solar system?\n",
      "A: The hottest planet in the solar system is Venus. Despite Mercury being closer to the Sun, Venus has a thick atmosphere that traps heat in a runaway version of the greenhouse effect, creating a much hotter surface temperature than Mercury. The average surface temperature on Venus is about 462°C (864°F), making it the hottest planet in our solar system.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for q, o in zip(queries, outputs):\n",
    "    print(f\"Q: {q}\")\n",
    "    print(f\"A: {o.outputs[0].text}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unload the model to free up the GPU for the PySpark section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextlib\n",
    "import gc\n",
    "import torch\n",
    "from vllm.distributed import destroy_model_parallel, destroy_distributed_environment\n",
    "\n",
    "def cleanup():\n",
    "    destroy_model_parallel()\n",
    "    destroy_distributed_environment()\n",
    "    with contextlib.suppress(AssertionError):\n",
    "        torch.distributed.destroy_process_group()\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "del llm\n",
    "cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pyspark.sql.types import *\n",
    "from pyspark import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat\n",
    "from pyspark.ml.functions import predict_batch_udf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2025-03-24 11:37:46] INFO config.py:54: PyTorch version 2.6.0 available.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import datasets\n",
    "from datasets import load_dataset\n",
    "datasets.disable_progress_bars()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create Spark Session\n",
    "\n",
    "For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \n",
    "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/03/24 11:37:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
      "25/03/24 11:37:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "25/03/24 11:37:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "conf = SparkConf()\n",
    "\n",
    "if 'spark' not in globals():\n",
    "    if on_standalone:\n",
    "        import socket\n",
    "        conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "        hostname = socket.gethostname()\n",
    "        conf.setMaster(f\"spark://{hostname}:7077\")\n",
    "        conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "    conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "    conf.set(\"spark.python.worker.reuse\", \"true\")\n",
    "\n",
    "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load and Preprocess DataFrame\n",
    "\n",
    "Load the first 500 samples of the [ML ArXiv dataset](https://huggingface.co/datasets/CShorten/ML-ArXiv-Papers) from Huggingface and store in a Spark Dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "ml_arxiv_dataset = load_dataset(\"CShorten/ML-ArXiv-Papers\", split=\"train\", streaming=True)\n",
    "ml_arxiv_pds = pd.Series([sample[\"abstract\"] for sample in ml_arxiv_dataset.take(500)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.createDataFrame(ml_arxiv_pds, schema=StringType())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------------------------------------------------------------------------------------+\n",
      "|                                                                                               value|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "|  The problem of statistical learning is to construct a predictor of a random\\nvariable $Y$ as a ...|\n",
      "|  In a sensor network, in practice, the communication among sensors is subject\\nto:(1) errors or ...|\n",
      "|  The on-line shortest path problem is considered under various models of\\npartial monitoring. Gi...|\n",
      "|  Ordinal regression is an important type of learning, which has properties of\\nboth classificati...|\n",
      "|  This paper uncovers and explores the close relationship between Monte Carlo\\nOptimization of a ...|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(5, truncate=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Format each sample into the chat template, including a system prompt to guide generation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = '''You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary \n",
    "of a research abstract that captures the main objective, methodology, and key findings, using clear \n",
    "language while preserving technical accuracy and quantitative results.'''\n",
    "\n",
    "df = df.select(\n",
    "    concat(\n",
    "        lit(\"<|im_start|>system\\n\"),\n",
    "        lit(system_prompt),\n",
    "        lit(\"<|im_end|>\\n<|im_start|>user\\n\"),\n",
    "        col(\"value\"),\n",
    "        lit(\"<|im_end|>\\n<|im_start|>assistant\\n\")\n",
    "    ).alias(\"prompt\")\n",
    ")"
   ]
  },
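  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the `concat` expression above lays out each prompt as a system turn, a user turn, and an open assistant turn, delimited by `<|im_start|>` and `<|im_end|>` markers. A minimal pure-Python sketch of the same string layout follows. `build_prompt` is a hypothetical helper that is not part of this notebook's pipeline, and the example strings are placeholders.\n",
    "\n",
    "```python\n",
    "def build_prompt(system: str, user: str) -> str:\n",
    "    \"\"\"Assemble a single-turn prompt with the same markers used in the Spark expression.\"\"\"\n",
    "    return (\n",
    "        \"<|im_start|>system\\n\"\n",
    "        f\"{system}<|im_end|>\\n\"\n",
    "        \"<|im_start|>user\\n\"\n",
    "        f\"{user}<|im_end|>\\n\"\n",
    "        \"<|im_start|>assistant\\n\"\n",
    "    )\n",
    "\n",
    "# Placeholder strings for illustration only.\n",
    "prompt = build_prompt(\"You are a helpful assistant.\", \"Summarize this abstract.\")\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "Building the string inside Spark keeps the formatting distributed across the cluster. When a tokenizer is available, `tokenizer.apply_chat_template`, as used in the warmup section, is generally the more robust way to produce this layout."
   ]
  },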
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|im_start|>system\n",
      "You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary \n",
      "of a research abstract that captures the main objective, methodology, and key findings, using clear \n",
      "language while preserving technical accuracy and quantitative results.<|im_end|>\n",
      "<|im_start|>user\n",
      "  The problem of statistical learning is to construct a predictor of a random\n",
      "variable $Y$ as a function of a related random variable $X$ on the basis of an\n",
      "i.i.d. training sample from the joint distribution of $(X,Y)$. Allowable\n",
      "predictors are drawn from some specified class, and the goal is to approach\n",
      "asymptotically the performance (expected loss) of the best predictor in the\n",
      "class. We consider the setting in which one has perfect observation of the\n",
      "$X$-part of the sample, while the $Y$-part has to be communicated at some\n",
      "finite bit rate. The encoding of the $Y$-values is allowed to depend on the\n",
      "$X$-values. Under suitable regularity conditions on the admissible predictors,\n",
      "the underlying family of probability distributions and the loss function, we\n",
      "give an information-theoretic characterization of achievable predictor\n",
      "performance in terms of conditional distortion-rate functions. The ideas are\n",
      "illustrated on the example of nonparametric regression in Gaussian noise.\n",
      "<|im_end|>\n",
      "<|im_start|>assistant\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(df.take(1)[0].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = \"spark-dl-datasets/arxiv_abstracts\"\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path = \"dbfs:/FileStore/\" + data_path\n",
    "\n",
    "df.write.mode(\"overwrite\").parquet(data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using vLLM Server\n",
    "In this section, we demonstrate integration with [vLLM Serving](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an open-source server with an OpenAI-compatible completions endpoint for LLMs.  \n",
    "\n",
    "The process looks like this:\n",
    "- Distribute a server startup task across the Spark cluster, instructing each node to launch a vLLM server process.\n",
    "- Define a vLLM inference function, which sends inference request to the local server on a given node.\n",
    "- Wrap the vLLM inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
    "- Finally, distribute a shutdown signal to terminate the vLLM server processes on each node.\n",
    "\n",
    "<img src=\"../images/spark-server.png\" alt=\"drawing\" width=\"700\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the helper class from server_utils.py:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.addPyFile(\"server_utils.py\")\n",
    "\n",
    "from server_utils import VLLMServerManager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Start vLLM servers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `VLLMServerManager` will handle the lifecycle of vLLM server instances across the Spark cluster:\n",
    "- Find available ports for HTTP\n",
    "- Deploy a server on each node via stage-level scheduling\n",
    "- Gracefully shutdown servers across nodes"
   ]
  },
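  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The implementation lives in `server_utils.py` and is not shown here. As an illustration of the first step only, one common way to find an available port is to bind a socket to port 0 and let the operating system choose. The sketch below shows that general technique under that assumption; it is not the helper's actual code, and `find_free_port` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import socket\n",
    "\n",
    "def find_free_port(host: str = \"127.0.0.1\") -> int:\n",
    "    \"\"\"Ask the OS for a currently unused TCP port by binding to port 0.\"\"\"\n",
    "    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n",
    "        s.bind((host, 0))\n",
    "        return s.getsockname()[1]\n",
    "\n",
    "port = find_free_port()\n",
    "print(f\"Free port: {port}\")\n",
    "```\n",
    "\n",
    "The port is released as soon as the socket closes, so another process could claim it before a server binds to it. A production helper would need to handle that race, for example by retrying on a bind failure."
   ]
  },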
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "model_name = \"qwen-2.5-7b\"\n",
    "server_manager = VLLMServerManager(model_name=model_name, model_path=model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can pass any of the supported [vLLM serve CLI arguments](https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html#vllm-serve) as key-word arguments when starting the servers. Note that this can take some time, as it includes loading the model from disk, Torch compilation, and capturing CUDA graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2025-03-24 11:37:57] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "[2025-03-24 11:37:57] INFO server_utils.py:390: Starting 1 VLLM servers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'cb4ae00-lcedt': (4022579, [7000])}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "server_manager.start_servers(gpu_memory_utilization=0.95,\n",
    "                             max_model_len=6600,\n",
    "                             task=\"generate\",\n",
    "                             enforce_eager=enforce_eager,\n",
    "                             wait_retries=60)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define client function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the hostname -> url mapping from the server manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "host_to_http_url = server_manager.host_to_http_url"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the vLLM inference function, which returns a predict function for batch inference through the server:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vllm_fn(model_name, host_to_url):\n",
    "    import socket\n",
    "    import numpy as np\n",
    "    import requests\n",
    "\n",
    "    url = host_to_url[socket.gethostname()]\n",
    "    \n",
    "    def predict(inputs):\n",
    "        response = requests.post(\n",
    "            f\"{url}/v1/completions\",\n",
    "            json={\n",
    "                \"model\": model_name,\n",
    "                \"prompt\": inputs.tolist(),\n",
    "                \"max_tokens\": 128,\n",
    "                \"temperature\": 0.7,\n",
    "                \"top_p\": 0.8,\n",
    "                \"repetition_penalty\": 1.05,\n",
    "            }\n",
    "        )\n",
    "        return np.array([r[\"text\"] for r in response.json()[\"choices\"]])\n",
    "    \n",
    "    return predict"
   ]
  },
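  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `predict` function above relies on the completions response containing one entry in `choices` per prompt. Each choice in an OpenAI-style completions response carries an `index` field, so a slightly more defensive variant can sort on it before extracting the text. The sketch below uses a hand-written sample payload rather than a real server response, and `extract_texts` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def extract_texts(response_json: dict) -> np.ndarray:\n",
    "    \"\"\"Return generated texts ordered by each choice's `index` field.\"\"\"\n",
    "    choices = sorted(response_json[\"choices\"], key=lambda c: c[\"index\"])\n",
    "    return np.array([c[\"text\"] for c in choices])\n",
    "\n",
    "# Hand-written sample payload for illustration; not output from a real server.\n",
    "sample = {\n",
    "    \"choices\": [\n",
    "        {\"index\": 1, \"text\": \"second\"},\n",
    "        {\"index\": 0, \"text\": \"first\"},\n",
    "    ]\n",
    "}\n",
    "print(extract_texts(sample).tolist())\n",
    "```\n",
    "\n",
    "Keeping outputs aligned with inputs matters here because `predict_batch_udf` pairs each returned element with the input row at the same position in the batch."
   ]
  },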
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "generate = predict_batch_udf(partial(vllm_fn, model_name=model_name, host_to_url=host_to_http_url),\n",
    "                             return_type=StringType(),\n",
    "                             batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load DataFrame\n",
    "\n",
    "We'll parallelize over a small set of prompts for demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.read.parquet(data_path).limit(256).repartition(8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 11:==================================================>       (7 + 1) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7.53 ms, sys: 2.19 ms, total: 9.72 ms\n",
      "Wall time: 13.9 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass caches model/fn and does JIT compilation\n",
    "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 17:===========================================>              (6 + 2) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10.7 ms, sys: 3.65 ms, total: 14.3 ms\n",
      "Wall time: 6.26 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q: <|im_start|>system\n",
      "You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary \n",
      "of a research abstract that captures the main objective, methodology, and key findings, using clear \n",
      "language while preserving technical accuracy and quantitative results.<|im_end|>\n",
      "<|im_start|>user\n",
      "  Images can be segmented by first using a classifier to predict an affinity\n",
      "graph that reflects the degree to which image pixels must be grouped together\n",
      "and then partitioning the graph to yield a segmentation. Machine learning has\n",
      "been applied to the affinity classifier to produce affinity graphs that are\n",
      "good in the sense of minimizing edge misclassification rates. However, this\n",
      "error measure is only indirectly related to the quality of segmentations\n",
      "produced by ultimately partitioning the affinity graph. We present the first\n",
      "machine learning algorithm for training a classifier to produce affinity graphs\n",
      "that are good in the sense of producing segmentations that directly minimize\n",
      "the Rand index, a well known segmentation performance measure. The Rand index\n",
      "measures segmentation performance by quantifying the classification of the\n",
      "connectivity of image pixel pairs after segmentation. By using the simple graph\n",
      "partitioning algorithm of finding the connected components of the thresholded\n",
      "affinity graph, we are able to train an affinity classifier to directly\n",
      "minimize the Rand index of segmentations resulting from the graph partitioning.\n",
      "Our learning algorithm corresponds to the learning of maximin affinities\n",
      "between image pixel pairs, which are predictive of the pixel-pair connectivity.\n",
      "<|im_end|>\n",
      "<|im_start|>assistant\n",
      " \n",
      "\n",
      "A: The research presents a machine learning algorithm that trains an affinity classifier to directly minimize the Rand index of image segmentations by producing affinity graphs optimized for pixel-pair connectivity, using a simple graph partitioning method. \n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(f\"Q: {results[0].prompt} \\n\")\n",
    "print(f\"A: {results[0].outputs} \\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Shut down server on each executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2025-03-24 11:38:49] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "[2025-03-24 11:38:50] INFO server_utils.py:447: Successfully stopped 1 VLLM servers.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[True]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "server_manager.stop_servers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n",
    "    spark.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spark-dl-vllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
